"""
Download AWS IP ranges and generate Rust code for CloudFront networks.
"""

import ipaddress
import json
import sys
import urllib.request
from pathlib import Path


# Location of the AWS IP ranges JSON document fetched by download_ip_ranges().
AWS_IP_RANGES_URL = "https://ip-ranges.amazonaws.com/ip-ranges.json"


def download_ip_ranges(timeout=30):
    """Download and parse the AWS IP ranges JSON document.

    Args:
        timeout: Seconds to wait on blocking network operations before
            giving up. Without a timeout a stalled connection would make
            the script hang indefinitely.

    Returns:
        The decoded JSON document, as produced by ``json.loads``.

    If the download or the JSON decoding fails, the error is printed to
    stderr and the process exits with status 1.
    """
    try:
        with urllib.request.urlopen(AWS_IP_RANGES_URL, timeout=timeout) as response:
            payload = response.read()
        # Decode after the connection has been released.
        return json.loads(payload)
    except Exception as e:
        # Script-level boundary: report any failure and stop with a non-zero
        # status instead of surfacing a traceback.
        print(f"Error downloading AWS IP ranges: {e}", file=sys.stderr)
        sys.exit(1)


def parse_ipv4_network(cidr):
    """Turn an IPv4 CIDR string into an ``ipaddress.IPv4Network``.

    If the value cannot be parsed, the error is printed to stderr and the
    process exits with status 1.
    """
    try:
        network = ipaddress.IPv4Network(cidr)
    except Exception as err:
        print(f"Error parsing IPv4 CIDR {cidr}: {err}", file=sys.stderr)
        sys.exit(1)
    return network


def parse_ipv6_network(cidr):
    """Turn an IPv6 CIDR string into an ``ipaddress.IPv6Network``.

    If the value cannot be parsed, the error is printed to stderr and the
    process exits with status 1.
    """
    try:
        network = ipaddress.IPv6Network(cidr)
    except Exception as err:
        print(f"Error parsing IPv6 CIDR {cidr}: {err}", file=sys.stderr)
        sys.exit(1)
    return network


def filter_cloudfront(data):
    """Extract and parse the CloudFront networks from the IP ranges document.

    Reads the ``prefixes`` and ``ipv6_prefixes`` lists of ``data`` (a missing
    list is treated as empty), keeps the entries whose ``service`` is
    ``"CLOUDFRONT"`` and parses their CIDR strings.

    Returns:
        A ``(ipv4_networks, ipv6_networks)`` tuple of sorted lists.
    """
    v4_found = []
    for entry in data.get("prefixes", []):
        if entry.get("service") != "CLOUDFRONT":
            continue
        v4_found.append(parse_ipv4_network(entry["ip_prefix"]))

    v6_found = []
    for entry in data.get("ipv6_prefixes", []):
        if entry.get("service") != "CLOUDFRONT":
            continue
        v6_found.append(parse_ipv6_network(entry["ipv6_prefix"]))

    # Sorted output keeps the generated file stable between runs.
    return sorted(v4_found), sorted(v6_found)


def generate_ipv4_line(network):
    """Render one IPv4 network as an entry of the generated Rust array."""
    # Iterating the packed address yields its four octets as integers.
    first, second, third, fourth = network.network_address.packed
    address = f"Ipv4Addr::new({first}, {second}, {third}, {fourth})"
    checked = f"Ipv4Network::new_checked({address}, {network.prefixlen})"
    return f"    IpNetwork::V4({checked}.unwrap()),"


def generate_ipv6_line(network):
    """Render one IPv6 network as an entry of the generated Rust array."""
    as_int = int(network.network_address)
    # Peel off the eight 16-bit groups, most significant first.
    groups = [hex((as_int >> shift) & 0xFFFF) for shift in range(112, -1, -16)]
    joined = ", ".join(groups)
    checked = f"Ipv6Network::new_checked(Ipv6Addr::new({joined}), {network.prefixlen})"
    return f"    IpNetwork::V6({checked}.unwrap()),"


def generate_rust_code(ipv4_networks, ipv6_networks):
    """Assemble the full text of the generated Rust module.

    The output is a fixed header, one array entry per IPv4 network followed
    by one per IPv6 network, the closing ``];`` and a trailing newline.
    """
    header = (
        "//! AUTO-GENERATED by update-data.py - DO NOT EDIT MANUALLY",
        "",
        "use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};",
        "use std::net::{Ipv4Addr, Ipv6Addr};",
        "",
        "/// CloudFront IP ranges from AWS.",
        "#[rustfmt::skip]",
        "pub const CLOUDFRONT_NETWORKS: &[IpNetwork] = &[",
    )

    entries = [generate_ipv4_line(net) for net in ipv4_networks]
    entries.extend(generate_ipv6_line(net) for net in ipv6_networks)

    # The trailing empty element makes the joined text end with a newline.
    return "\n".join([*header, *entries, "];", ""])


def main():
    """Run the full pipeline: download, filter, generate, write.

    Writes the generated Rust source to ``src/cloudfront.rs`` next to this
    script. Exits with status 1 if no CloudFront networks were found.
    """
    print("Downloading AWS IP ranges...")
    ranges = download_ip_ranges()

    print("Filtering CloudFront prefixes...")
    v4_nets, v6_nets = filter_cloudfront(ranges)

    v4_count, v6_count = len(v4_nets), len(v6_nets)
    if v4_count + v6_count == 0:
        # Do not write an empty list over the existing generated file.
        print("Error: No CloudFront networks found!", file=sys.stderr)
        sys.exit(1)

    print(f"Found {v4_count} IPv4 and {v6_count} IPv6 CloudFront networks")

    print("Generating Rust code...")
    generated = generate_rust_code(v4_nets, v6_nets)

    destination = Path(__file__).parent / "src" / "cloudfront.rs"
    print(f"Writing to {destination}...")
    destination.write_text(generated)

    print("Done!")


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
